
from collections import deque
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any

import numpy as np

from .gates import ThetaLadder, KappaLadder, StructuralGates, CRA
from .lex import ratio_lex_score, compare_ratio_lex
from .tie_pf import tie_break_pf_born
from .lints import Lints
from .scenes import SceneConfig, profile_from_roi, find_peaks_trough

@dataclass
class RunManifest:
    """Bundle of run-level configuration handed to ``PresentActEngine``.

    The field comments describe only how ``PresentActEngine`` uses each
    entry; the component types themselves are declared in sibling modules.
    """
    # Theta ladder; the engine reads ``theta.bins`` and calls ``theta.max_bin()``.
    theta: ThetaLadder
    # Kappa ladder; not referenced by ``PresentActEngine``.
    kappa: KappaLadder
    # Structural gates; the engine reads ``structural.contiguity``.
    structural: StructuralGates
    # CRA settings; the engine reads ``cra.enabled``.
    cra: CRA
    # Lint hooks; the engine calls ``lints.assert_no_skip(1)`` during its search.
    lints: Lints
    # Seed passed to ``np.random.seed`` when an engine is constructed.
    seed: int
    # Not referenced by ``PresentActEngine``.
    # NOTE(review): presumably a unit-scaling constant — confirm with callers.
    c_units: float = 1.0

@dataclass
class Result:
    """Outcome record returned alongside the target by ``PresentActEngine.accept``."""
    # Index of the accepted candidate within the group of best-scoring ties
    # (0 when there was no tie to break); -1 when nothing was accepted.
    accepted_index: int
    # Score tuple of the accepted candidate; (0, 0, 0) when nothing was accepted.
    accepted_score: Tuple[int,int,int]
    # Weights returned by the tie-breaker as a plain list; empty when no
    # tie-break ran.
    tie_weights: List[float]
    # Number of candidates sharing the best score; 0 when nothing was accepted.
    tie_count: int
    # True when the tie-breaker was invoked, False otherwise.
    rng_used: bool
    # Extra information: "candidates"/"survivors" counts on success, or a
    # "reason" string when nothing was accepted.
    diagnostics: Dict[str,Any]

class PresentActEngine:
    def __init__(self, scene: SceneConfig, manifest: RunManifest):
        self.scene = scene
        self.m = manifest
        np.random.seed(self.m.seed)

    def _neighbors4(self, x, y):
        return [(x+1,y),(x-1,y),(x,y+1),(x,y-1)]

    def _in_bounds(self, x, y):
        return (0 <= x < self.scene.W) and (0 <= y < self.scene.H)

    def _allowed(self, x, y):
        if not self._in_bounds(x,y): return False
        if self.scene.allowed_mask is None: return True
        return bool(self.scene.allowed_mask[y, x])

    def _phase_bin(self, x, y):
        n = max(self.m.theta.bins)
        if n <= 0: n = 1
        return int((x + y) % n)

    def _disc_bits(self, x, y):
        return int(((x*y) // max(1,self.scene.w_inner)) % 2)

    def _features(self, x, y):
        return (self._phase_bin(x,y), self._disc_bits(x,y))

    def _bfs_to_roi_midline(self, sx, sy):
        x0,y0,x1,y1 = self.scene.roi_bbox
        target_y = (y0+y1)//2
        max_depth = self.m.theta.max_bin()
        visited = set()
        q = [(sx,sy,0,[(sx,sy)])]
        while q:
            x,y,d,path = q.pop(0)
            if (x,y) in visited: continue
            visited.add((x,y))
            if d>0: self.m.lints.assert_no_skip(1)
            if y == target_y and x0 <= x < x1:
                return path, (x,y)
            if d >= max_depth:
                continue
            for nx,ny in self._neighbors4(x,y):
                if not self._allowed(nx,ny): continue
                if self.m.structural.contiguity and (abs(nx-x)+abs(ny-y) != 1):
                    continue
                q.append((nx,ny,d+1,path+[(nx,ny)]))
        return None, None

    def propose_candidates(self, sources: List[Tuple[int,int]], screen):
        cands = []
        for sx,sy in sources:
            out = self._bfs_to_roi_midline(sx,sy)
            if out is None or out[0] is None:
                continue
            path, (tx,ty) = out
            feats = self._features(tx,ty)
            # prospective profile: add +1 at (tx, mid) then score
            x0,y0,x1,y1 = self.scene.roi_bbox
            roi_mid = (y0+y1)//2
            screen[roi_mid, tx] += 1
            profile = profile_from_roi(self.scene, screen)
            p1,p2,t = find_peaks_trough(profile)
            score = ratio_lex_score(p1,p2,t)
            screen[roi_mid, tx] -= 1
            cands.append((path, (tx,ty), feats, score, (p1,p2,t)))
        return cands

    def accept(self, candidates):
        if not candidates:
            return None, Result(-1,(0,0,0),[],0,False,{"reason":"no candidates"})
        target_feat = candidates[0][2]
        survivors = [c for c in candidates if c[2] == target_feat]
        if self.m.cra.enabled and len(survivors) > 1:
            uniq = {}
            for c in survivors:
                uniq[tuple(c[0])] = c
            survivors = list(uniq.values())
        if not survivors:
            return None, Result(-1,(0,0,0),[],0,False,{"reason":"no survivors"})
        best = survivors[0]
        ties = [best]
        for c in survivors[1:]:
            cmpv = compare_ratio_lex(c[3], best[3])
            if cmpv > 0: best, ties = c, [c]
            elif cmpv == 0: ties.append(c)
        rng_used = False; tie_w = []; accepted = best; acc_idx = 0
        if len(ties) > 1:
            n = len(ties)
            import numpy as np
            A = np.zeros((n,n), dtype=float)
            for i in range(n):
                for j in range(n):
                    if i==j: continue
                    if abs(len(ties[i][0]) - len(ties[j][0])) <= 1:
                        A[i,j] = 1.0
            idx, w = tie_break_pf_born(A)
            accepted = ties[idx]; rng_used = True; tie_w = w.tolist(); acc_idx = idx
        (path,(tx,ty),feats,score,pkt) = accepted
        return (tx,ty), Result(acc_idx, score, tie_w, len(ties), rng_used,
                               {"candidates": len(candidates), "survivors": len(survivors)})
